# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//mlir:tblgen.bzl", "gentbl_cc_library")

# Package-wide defaults: targets are publicly visible and carry the "notice"
# license unless a target overrides them.
package(default_visibility = ["//visibility:public"], licenses = ["notice"])

# Unit tests built from the .cpp/.h files directly under IR/.
cc_test(
    name = "ir_tests",
    size = "small",
    srcs = glob(["IR/*.cpp", "IR/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:IR",
        "//mlir/test:TestDialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Interfaces/.
cc_test(
    name = "interface_tests",
    size = "small",
    srcs = glob(["Interfaces/*.cpp", "Interfaces/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:ControlFlowInterfaces",
        "//mlir:DLTIDialect",
        "//mlir:DataLayoutInterfaces",
        "//mlir:FuncDialect",
        "//mlir:IR",
        "//mlir:InferIntRangeInterface",
        "//mlir:InferTypeOpInterface",
        "//mlir:Parser",
    ],
)

# Unit tests built from the .cpp/.h files directly under Support/.
cc_test(
    name = "support_tests",
    size = "small",
    srcs = glob(["Support/*.cpp", "Support/*.h"]),
    deps = [
        "//llvm:Support",
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:Support",
    ],
)

# Unit tests built from the .cpp/.h files directly under Pass/.
cc_test(
    name = "pass_tests",
    size = "small",
    srcs = glob(["Pass/*.cpp", "Pass/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:Analysis",
        "//mlir:FuncDialect",
        "//mlir:IR",
        "//mlir:Pass",
    ],
)

# Unit tests built from the .cpp/.h files directly under Rewrite/.
cc_test(
    name = "rewrite_tests",
    size = "small",
    srcs = glob(["Rewrite/*.cpp", "Rewrite/*.h"]),
    deps = [
        "//llvm:gtest_main",
        "//mlir:IR",
        "//mlir:Rewrite",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/. The
# single-level patterns do not descend into Dialect/ subdirectories.
cc_test(
    name = "dialect_tests",
    size = "small",
    srcs = glob(["Dialect/*.cpp", "Dialect/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:Dialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/MemRef/.
cc_test(
    name = "memref_tests",
    size = "small",
    srcs = glob(["Dialect/MemRef/*.cpp", "Dialect/MemRef/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:MemRefDialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/Quant/.
cc_test(
    name = "quantops_tests",
    size = "small",
    srcs = glob(["Dialect/Quant/*.cpp", "Dialect/Quant/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:QuantOps",
        "//mlir:Transforms",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/SCF/.
cc_test(
    name = "scf_tests",
    size = "small",
    srcs = glob(["Dialect/SCF/*.cpp", "Dialect/SCF/*.h"]),
    deps = [
        "//llvm:gtest_main",
        "//mlir:FuncDialect",
        "//mlir:Parser",
        "//mlir:SCFDialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/SparseTensor/.
cc_test(
    name = "sparse_tensor_tests",
    size = "small",
    srcs = glob(["Dialect/SparseTensor/*.cpp", "Dialect/SparseTensor/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:SparseTensorUtils",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/SPIRV/.
cc_test(
    name = "spirv_tests",
    size = "small",
    srcs = glob(["Dialect/SPIRV/*.cpp", "Dialect/SPIRV/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:SPIRVDeserialization",
        "//mlir:SPIRVDialect",
        "//mlir:SPIRVSerialization",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/Transform/.
cc_test(
    name = "transform_dialect_tests",
    size = "small",
    srcs = glob(["Dialect/Transform/*.cpp", "Dialect/Transform/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:FuncDialect",
        "//mlir:TransformDialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Dialect/Utils/.
cc_test(
    name = "dialect_utils_tests",
    size = "small",
    srcs = glob(["Dialect/Utils/*.cpp", "Dialect/Utils/*.h"]),
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:DialectUtils",
    ],
)

# Runs mlir-tblgen over TableGen/enums.td to produce the enum declaration
# (.h.inc) and definition (.cpp.inc) files.
gentbl_cc_library(
    name = "EnumsIncGen",
    tbl_outs = [
        (["-gen-enum-decls"], "TableGen/EnumsGenTest.h.inc"),
        (["-gen-enum-defs"], "TableGen/EnumsGenTest.cpp.inc"),
    ],
    tblgen = "//mlir:mlir-tblgen",
    td_file = "TableGen/enums.td",
    deps = ["//mlir:OpBaseTdFiles"],
)

# Runs mlir-tblgen over TableGen/passes.td to produce the pass declaration
# file (.h.inc).
gentbl_cc_library(
    name = "PassIncGen",
    tbl_outs = [
        (["-gen-pass-decls"], "TableGen/PassGenTest.h.inc"),
    ],
    tblgen = "//mlir:mlir-tblgen",
    td_file = "TableGen/passes.td",
    deps = [
        "//mlir:PassBaseTdFiles",
        "//mlir:RewritePassBaseTdFiles",
    ],
)

# Unit tests built from the .cpp/.h files directly under TableGen/.
#
# glob() only matches checked-in source files, so the .inc files generated by
# :EnumsIncGen and :PassIncGen are listed explicitly. All generated outputs are
# listed so that every tblgen output is handled the same way.
# NOTE(review): PassGenTest.h.inc may already be reachable through the
# :PassIncGen dependency alone; listing it here is for consistency with the
# Enums outputs.
cc_test(
    name = "tablegen_tests",
    size = "small",
    srcs = glob([
        "TableGen/*.cpp",
        "TableGen/*.h",
    ]) + [
        "TableGen/EnumsGenTest.cpp.inc",
        "TableGen/EnumsGenTest.h.inc",
        "TableGen/PassGenTest.h.inc",
    ],
    # Adds TableGen/ to the include search path for this target.
    includes = ["TableGen/"],
    deps = [
        ":EnumsIncGen",
        ":PassIncGen",
        "//llvm:Support",
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:IR",
        "//mlir:TableGen",
        "//mlir/test:TestDialect",
    ],
)

# Unit tests built from the .cpp/.h files directly under Transforms/.
cc_test(
    name = "transforms_test",
    size = "small",
    srcs = glob(["Transforms/*.cpp", "Transforms/*.h"]),
    deps = [
        "//llvm:gtest_main",
        "//mlir:AffineAnalysis",
        "//mlir:IR",
        "//mlir:Parser",
        "//mlir:Pass",
        "//mlir:TransformUtils",
        "//mlir:Transforms",
    ],
)

# Unit tests built from the .cpp/.h files under Analysis/ and its immediate
# subdirectories, plus the AffineStructuresParser.* files from
# Dialect/Affine/Analysis/.
cc_test(
    name = "analysis_tests",
    size = "small",
    srcs = glob([
        "Analysis/*.cpp",
        "Analysis/*.h",
        "Analysis/*/*.cpp",
        "Analysis/*/*.h",
        "Dialect/Affine/Analysis/AffineStructuresParser.*",
    ]),
    # Kept in sorted order, matching the other targets in this file.
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:AffineAnalysis",
        "//mlir:Analysis",
        "//mlir:IR",
        "//mlir:Parser",
    ],
)

# Unit tests built from the .cpp/.h files under Conversion/ and its immediate
# subdirectories.
cc_test(
    name = "conversion_tests",
    size = "small",
    srcs = glob([
        "Conversion/*/*.cpp",
        "Conversion/*/*.h",
        "Conversion/*.cpp",
        "Conversion/*.h",
    ]),
    deps = [
        "//llvm:gtest_main",
        "//mlir:ArithmeticDialect",
        "//mlir:PDLToPDLInterp",
    ],
)

# Unit tests built from the .cpp files directly under ExecutionEngine/.
cc_test(
    name = "execution_engine_tests",
    size = "small",
    srcs = glob(["ExecutionEngine/*.cpp"]),
    # MSAN does not work with JIT, so opt this target out via the "nomsan" tag.
    tags = ["nomsan"],
    deps = [
        "//llvm:TestingSupport",
        "//llvm:gtest_main",
        "//mlir:AllPassesAndDialects",
        "//mlir:Analysis",
        "//mlir:ExecutionEngine",
        "//mlir:IR",
        "//mlir:LinalgToLLVM",
        "//mlir:MemRefToLLVM",
        "//mlir:mlir_c_runner_utils",
        "//mlir:mlir_runner_utils",
    ],
)
